"""
wine_quality_kd.py
------------------
Teacher–student knowledge distillation with VQCs on the UCI Wine-Quality data.

• Teacher  : 10-qubit EfficientSU2 (reps = 2)
• Student  :  6-qubit EfficientSU2 (reps = 1), fed only the first 6 PCA components
• Features : 11 physicochemical attributes → StandardScaler → PCA(10)
• Encoding : angle encoding, θ = PCA component rescaled to [0, π]
• Distill  : teacher generates hard pseudo-labels for each training sample
"""

# ---------------------------------------------------------------
# 0. Imports
# ---------------------------------------------------------------
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit_aer import AerSimulator
from qiskit.circuit.library import EfficientSU2
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC
from qiskit.utils import QuantumInstance


# ---------------------------------------------------------------
# 1. Distillation “server”
# ---------------------------------------------------------------
class DistillationServer:
    """Train a large "teacher" VQC and distil it into a smaller "student" VQC.

    Construction downloads the red and white UCI Wine-Quality CSVs, builds
    angle-encoded train/test features and instantiates both (still untrained)
    classifiers.  Call ``train_teacher()``, ``train_student()`` and
    ``report()`` in that order.

    Args:
        num_teacher_qubits: Width of the teacher circuit.
        num_student_qubits: Width of the student circuit.
        pca_components: Number of PCA components kept as features.
        seed: Seed used for the data split, PCA, the simulators and the
            initial variational weights.

    Attributes set during construction:
        X_train, X_test: Rotation angles in [0, π], one column per PCA
            component.
        y_train, y_test: Integer quality scores.
        n_classes: Number of distinct quality scores present in the data.
        teacher, student: The two ``VQC`` instances.
    """

    def __init__(
        self,
        num_teacher_qubits: int = 10,
        num_student_qubits: int = 6,
        pca_components: int = 10,
        seed: int = 123,
    ):
        self.seed = seed
        self.num_teacher_qubits = num_teacher_qubits
        self.num_student_qubits = num_student_qubits
        self.pca_components = pca_components
        self._rng()

        self._load_data()
        self._build_teacher()
        self._build_student()

    # 1.1  reproducibility
    def _rng(self):
        """Seed NumPy's global state and create the generator for initial weights."""
        np.random.seed(self.seed)
        # The global NumPy seed is not what the VQC uses to draw its starting
        # weights, so the initial points are generated here and passed in
        # explicitly (see _build_vqc).
        self._generator = np.random.default_rng(self.seed)

    # 1.2  download + preprocess UCI wine-quality CSVs
    def _load_data(self):
        """Download both CSVs, split, and turn the features into angles.

        The scaler, the PCA and the min/max used for the angle mapping are all
        fitted on the training split only, so no test-set statistics leak into
        training.
        """
        print("⇨ Loading Wine-Quality data …")

        red = pd.read_csv(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/"
            "wine-quality/winequality-red.csv",
            sep=";",
        )
        white = pd.read_csv(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/"
            "wine-quality/winequality-white.csv",
            sep=";",
        )
        data = pd.concat([red, white], ignore_index=True)

        X = data.drop(columns=["quality"]).values.astype(np.float32)
        y = data["quality"].values.astype(int)

        # Count the classes instead of hard-coding them: the red file stops at
        # score 8 but the white file also contains 9.
        self.n_classes = int(np.unique(y).size)

        raw_train, raw_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=0.2, stratify=y, random_state=self.seed
        )

        # standardise → PCA, fitted on the training rows only
        scaler = StandardScaler().fit(raw_train)
        pca = PCA(n_components=self.pca_components, random_state=self.seed)
        z_train = pca.fit_transform(scaler.transform(raw_train))
        z_test = pca.transform(scaler.transform(raw_test))

        # map every component to [0, π] using the training range
        lo, hi = z_train.min(axis=0), z_train.max(axis=0)
        self.X_train = self._to_angles(z_train, lo, hi)
        self.X_test = self._to_angles(z_test, lo, hi)

        print(f"   • Train size: {self.X_train.shape[0]}")
        print(f"   • Test  size: {self.X_test.shape[0]}")

    @staticmethod
    def _to_angles(X, lo, hi):
        """Linearly map each column of ``X`` from ``[lo, hi]`` onto ``[0, π]``.

        Args:
            X: 2-D array of features.
            lo: Per-column lower bounds (broadcast against ``X``).
            hi: Per-column upper bounds (broadcast against ``X``).

        Returns:
            Array shaped like ``X``.  Values outside ``[lo, hi]`` are clipped
            to the interval ends; a column with ``hi == lo`` maps to 0.
        """
        span = np.asarray(hi, dtype=float) - np.asarray(lo, dtype=float)
        # Guard against constant columns (zero range → division by zero).
        span = np.where(span > 0, span, 1.0)
        unit = np.clip((np.asarray(X, dtype=float) - lo) / span, 0.0, 1.0)
        return unit * np.pi

    # 1.3  angle-encoding circuit generator
    @staticmethod
    def _feature_map(x, num_qubits):
        """Return a circuit applying ``RY(x[i])`` to qubit ``i``.

        Args:
            x: Iterable of rotation angles.  Entries may be numbers or circuit
                parameters; entries beyond ``num_qubits`` are ignored.
            num_qubits: Width of the circuit.
        """
        from qiskit import QuantumCircuit

        qc = QuantumCircuit(num_qubits)
        # zip() stops at the shorter side, i.e. at most one angle per qubit.
        for qubit, theta in zip(range(num_qubits), x):
            qc.ry(theta, qubit)
        return qc

    def _inputs_for(self, X, num_qubits):
        """Return the leading columns of ``X`` that a ``num_qubits`` model encodes."""
        width = min(num_qubits, self.pca_components)
        return np.asarray(X)[:, :width]

    def _build_vqc(self, num_qubits, reps):
        """Create an untrained VQC with RY angle encoding and an EfficientSU2 ansatz."""
        from qiskit.algorithms.optimizers import COBYLA
        from qiskit.circuit import ParameterVector

        # VQC needs a parameterised circuit as feature map, one parameter per
        # input feature; the data are bound to it sample by sample at run time.
        n_inputs = min(num_qubits, self.pca_components)
        feature_map = self._feature_map(ParameterVector("x", n_inputs), num_qubits)
        ansatz = EfficientSU2(num_qubits, reps=reps)

        # The number of classes is not a constructor argument; it is taken
        # from the labels passed to fit().
        return VQC(
            feature_map=feature_map,
            ansatz=ansatz,
            optimizer=COBYLA(),
            quantum_instance=AerSimulator(seed_simulator=self.seed),
            initial_point=self._generator.random(ansatz.num_parameters),
        )

    # 1.4  build teacher VQC
    def _build_teacher(self):
        """Instantiate the teacher (``reps=2``) as ``self.teacher``."""
        self.teacher = self._build_vqc(self.num_teacher_qubits, reps=2)

    # 1.5  build student VQC (smaller, shallower)
    def _build_student(self):
        """Instantiate the student (``reps=1``) as ``self.student``."""
        self.student = self._build_vqc(self.num_student_qubits, reps=1)

    # 1.6  stage 1 – train teacher on ground-truth labels
    def train_teacher(self):
        """Fit the teacher on the training split with the true quality scores."""
        print("\n⇨ Training teacher VQC …")
        # NOTE(review): integer labels are passed as-is and are expected to be
        # one-hot encoded inside VQC.fit – confirm for the installed
        # qiskit-machine-learning version.
        self.teacher.fit(
            self._inputs_for(self.X_train, self.num_teacher_qubits), self.y_train
        )

    # 1.7  stage 2 – teacher generates pseudo-labels (hard)
    def _pseudo_labels(self, X):
        """Return the teacher's hard predictions for the full-width features ``X``."""
        return self.teacher.predict(self._inputs_for(X, self.num_teacher_qubits))

    # 1.8  stage 3 – train student on teacher labels
    def train_student(self):
        """Fit the student on the training features labelled by the teacher.

        ``train_teacher()`` must have been called first.
        """
        print("\n⇨ Generating pseudo-labels …")
        pseudo = self._pseudo_labels(self.X_train)
        print("⇨ Training student VQC on pseudo-labels …")
        self.student.fit(
            self._inputs_for(self.X_train, self.num_student_qubits), pseudo
        )

    # 1.9  accuracy helper
    @staticmethod
    def _acc(model, X, y):
        """Return the fraction of rows of ``X`` for which ``model.predict`` equals ``y``.

        ``X`` must already have the width the model was built for.
        """
        return float(np.mean(np.asarray(model.predict(X)) == np.asarray(y)))

    def report(self):
        """Print teacher and student accuracy on the held-out test split."""
        teacher_X = self._inputs_for(self.X_test, self.num_teacher_qubits)
        student_X = self._inputs_for(self.X_test, self.num_student_qubits)
        print("\n=== Test accuracy ===")
        print(f"Teacher : {self._acc(self.teacher, teacher_X, self.y_test):.3f}")
        print(f"Student : {self._acc(self.student, student_X, self.y_test):.3f}")


# ---------------------------------------------------------------
# 2. Main
# ---------------------------------------------------------------
if __name__ == "__main__":
    # Run the whole pipeline: fit the teacher, distil it into the student,
    # then print both test accuracies.
    pipeline = DistillationServer()
    for stage in (pipeline.train_teacher, pipeline.train_student, pipeline.report):
        stage()
